import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when a real token is supplied via
# the HF_TOKEN environment variable; credentials must never be hard-coded here.
_hf_login_token = os.environ.get("HF_TOKEN", "").strip()
if _hf_login_token and not _hf_login_token.startswith("to enter"):
    login(token=_hf_login_token)

import os
import re
import math
import csv
import time
import collections
import hashlib
import random

import torch
# Keep cuDNN enabled and turn on its autotuner (benchmark mode).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from tqdm import tqdm

# -------------------------------
# Config (placeholders: set via env)
# -------------------------------
def _env_or_placeholder(name):
    """Return env var ``name``, or the ``"to enter values of <name>"`` placeholder if unset."""
    return os.environ.get(name, f"to enter values of {name}")


# HF_TOKEN is forwarded as ``token=`` to from_pretrained, so an unset / placeholder
# value becomes None (letting huggingface_hub use its default credential lookup)
# instead of a bogus token string being sent to the Hub.
_hf_token_env = os.environ.get("HF_TOKEN", "").strip()
HF_TOKEN = _hf_token_env if _hf_token_env and not _hf_token_env.startswith("to enter") else None
MODEL_NAME = _env_or_placeholder("MODEL_NAME")

# LoRA directories (placeholders) -- 9 LoRA paths for this single pipeline
TENSORS_DIR_FULL_MLPATT = _env_or_placeholder("TENSORS_DIR_FULL_MLPATT")
TENSORS_DIR_FULL_ATT    = _env_or_placeholder("TENSORS_DIR_FULL_ATT")
TENSORS_DIR_FULL_MLP    = _env_or_placeholder("TENSORS_DIR_FULL_MLP")

TENSORS_DIR_TOP1_MLPATT = _env_or_placeholder("TENSORS_DIR_TOP1_MLPATT")
TENSORS_DIR_TOP1_ATT    = _env_or_placeholder("TENSORS_DIR_TOP1_ATT")
TENSORS_DIR_TOP1_MLP    = _env_or_placeholder("TENSORS_DIR_TOP1_MLP")

TENSORS_DIR_TOP3_MLPATT = _env_or_placeholder("TENSORS_DIR_TOP3_MLPATT")
TENSORS_DIR_TOP3_ATT    = _env_or_placeholder("TENSORS_DIR_TOP3_ATT")
TENSORS_DIR_TOP3_MLP    = _env_or_placeholder("TENSORS_DIR_TOP3_MLP")

# Hyperparameters placeholders (must be set to integers via env)
def _parse_required_int_env(name):
    """Read environment variable ``name`` and return it as an ``int``.

    Raises:
        RuntimeError: if the variable is unset (or still holds the
            ``"to enter ..."`` placeholder), or is not a valid integer literal.
            In the latter case the original ``ValueError`` is chained as the cause.
    """
    v = os.environ.get(name, f"to enter values of {name}")
    if v.strip().startswith("to enter"):
        raise RuntimeError(f"Environment variable {name} must be set to an integer. Current value is a placeholder: {v!r}")
    try:
        return int(v)
    except ValueError as err:
        raise RuntimeError(f"Environment variable {name} must be an integer. Got: {v!r}") from err

# Required integer run settings, parsed in this order; the first invalid one aborts startup.
N, NUM_K_SAMPLING, BATCH_SIZE, MAX_NEW_TOKENS, PREVIEW = (
    _parse_required_int_env(_var)
    for _var in ("N_SAMPLES", "NUM_K_SAMPLING", "BATCH_SIZE", "MAX_NEW_TOKENS", "PREVIEW")
)

# Other settings
EXEC_TIMEOUT = int(os.environ.get("EXEC_TIMEOUT", "8"))
CSV_OUT = os.environ.get("CSV_OUT", "to enter values of CSV_OUT")
LOG_DIR = os.environ.get("LOG_DIR", "to enter values of LOG_DIR")

# Optional reproducibility. Sampling in model.generate draws from torch's RNG,
# so torch must be seeded as well as Python's `random`.
random_seed_env = os.environ.get("RANDOM_SEED", None)
if random_seed_env:
    try:
        _seed = int(random_seed_env)
    except ValueError:
        print(f"Warning: ignoring non-integer RANDOM_SEED={random_seed_env!r}")
    else:
        random.seed(_seed)
        torch.manual_seed(_seed)

# create LOG_DIR if provided
if LOG_DIR and not LOG_DIR.startswith("to enter"):
    os.makedirs(LOG_DIR, exist_ok=True)

# -------------------------------
# Tokenization / device caches (shared across adapters)
# -------------------------------
# prompt-hash -> CPU encoding, and (prompt-hash, str(device)) -> tensors on that device.
# Both are shared by every model and reset by _clear_token_caches().
token_cache_cpu, device_input_cache = {}, {}

def _prompt_hash(s: str) -> str:
    """Return the hex SHA-256 digest of ``s`` (UTF-8 encoded), used as a cache key."""
    hasher = hashlib.sha256()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()

def get_tokenized_inputs(tok, prompt: str, device: torch.device):
    """Return ``tok(prompt, return_tensors="pt")`` as a dict of tensors on ``device``.

    Two memo layers are used: the CPU encoding keyed by the prompt hash
    (``token_cache_cpu``), and the moved tensors keyed by
    ``(prompt hash, str(device))`` (``device_input_cache``). The keys ignore the
    tokenizer, so the caches must be cleared before encoding with a different one.
    When a prompt is encoded for the first time and ``tok`` has no pad token, its
    eos token is assigned as pad. Each tensor is first moved via pinned memory with
    ``non_blocking=True``; if that raises, a plain ``.to(device)`` is used instead.
    """
    digest = _prompt_hash(prompt)
    cache_key = (digest, str(device))
    if cache_key in device_input_cache:
        return device_input_cache[cache_key]

    if digest not in token_cache_cpu:
        encoded = tok(prompt, return_tensors="pt")
        if getattr(tok, "pad_token", None) is None:
            tok.pad_token = tok.eos_token
        token_cache_cpu[digest] = encoded

    def _move(tensor):
        try:
            return tensor.pin_memory().to(device, non_blocking=True)
        except Exception:
            return tensor.to(device)

    moved = {name: _move(tensor) for name, tensor in token_cache_cpu[digest].items()}
    device_input_cache[cache_key] = moved
    return moved

def _clear_token_caches():
    """Reset both tokenization caches by rebinding them to fresh empty dicts."""
    global device_input_cache, token_cache_cpu
    device_input_cache, token_cache_cpu = {}, {}

# -------------------------------
# Utilities: load_pretrained + load_lora (minimal prints)
# -------------------------------
def load_pretrained(model_name, hf_token):
    """Load the base causal LM (bf16, ``device_map="auto"``) together with its fast tokenizer.

    A tokenizer without pad/bos tokens gets its eos token for both. The model
    config's pad/eos/bos ids are then synced to the tokenizer (best effort).

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "bos_token", None) is None:
        tokenizer.bos_token = tokenizer.eos_token
    try:
        for attr in ("pad_token_id", "eos_token_id", "bos_token_id"):
            setattr(model.config, attr, getattr(tokenizer, attr))
    except Exception:
        pass
    return model.eval(), tokenizer

def load_lora(base_model, hf_token, tensors_dir, r, alpha):
    """Load ``base_model``, attach a fresh LoRA adapter and fill it from ``.pt`` tensors.

    Every ``*.pt`` file found recursively under ``tensors_dir`` is loaded on the CPU
    and keyed by its file name minus the ``.pt`` extension. Each key is matched
    against the PEFT model's ``state_dict`` keys by trying, in order, the prefixes
    ``""``, ``"base_model."`` and ``"model."``. Each prefix is combined with the
    suffixes ``""``, ``".weight"`` and ``".default.weight"``, and the first existing
    key wins. Matched tensors are loaded with ``strict=False``.

    Args:
        base_model: Model id or path passed to ``from_pretrained``.
        hf_token: Hugging Face token passed to ``from_pretrained``.
        tensors_dir: Directory holding the ``.pt`` tensors. When falsy, the adapter
            keeps its fresh initialisation.
        r: LoRA rank.
        alpha: ``lora_alpha`` scaling.

    Returns:
        ``(peft_model, tokenizer)``. The model is in eval mode and the tokenizer's
        pad token is set to its eos token.

    Stdout stays minimal: one line on success, plus warnings for a missing
    directory, unreadable files, and tensors that match no parameter.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True
    )
    tok = AutoTokenizer.from_pretrained(base_model, token=hf_token, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if getattr(tok, "bos_token", None) is None:
        tok.bos_token = tok.eos_token
    try:
        model.config.pad_token_id = tok.pad_token_id
        model.config.eos_token_id = tok.eos_token_id
        model.config.bos_token_id = tok.bos_token_id
    except Exception:
        pass  # best effort: leave a config that rejects these ids untouched

    cfg = LoraConfig(
        r=r, lora_alpha=alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "down_proj", "up_proj", "gate_proj"],
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, cfg)

    # Read the raw .pt tensors (if a directory was given).
    state_raw = {}
    if tensors_dir and os.path.exists(tensors_dir):
        for root, _, files in os.walk(tensors_dir):
            for fname in files:
                if not fname.endswith(".pt"):
                    continue
                # Strip only the trailing extension; str.replace would also drop a
                # ".pt" occurring inside the parameter name.
                key = fname[:-len(".pt")]
                path = os.path.join(root, fname)
                try:
                    # NOTE: torch.load unpickles the file -- point tensors_dir only at trusted data.
                    state_raw[key] = torch.load(path, map_location="cpu")
                except Exception as err:
                    # Skip the file but say so. Otherwise the adapter would look
                    # loaded while this parameter keeps its fresh initialisation.
                    print(f"Warning: could not load {path}: {err}")
    elif tensors_dir:
        print(f"Warning: tensors_dir provided but path does not exist: {tensors_dir}")

    # Map raw keys onto the PEFT state_dict keys (first matching candidate wins).
    target_keys = set(peft_model.state_dict().keys())
    prefixes = ("", "base_model.", "model.")
    suffixes = ("", ".weight", ".default.weight")
    mapped = {}
    unmapped = []
    for raw_k, tensor in state_raw.items():
        candidates = (prefix + raw_k + suffix for prefix in prefixes for suffix in suffixes)
        target = next((cand for cand in candidates if cand in target_keys), None)
        if target is None:
            unmapped.append(raw_k)
        else:
            mapped[target] = tensor

    if mapped:
        peft_model.load_state_dict(mapped, strict=False)
        print(f"Loaded LoRA weights from: {tensors_dir}")
        if unmapped:
            print(f"Warning: {len(unmapped)} of {len(state_raw)} tensors in {tensors_dir} matched no model parameter.")
    elif tensors_dir and os.path.exists(tensors_dir):
        print(f"Attempted load from {tensors_dir} (no tensors mapped).")

    peft_model.eval()
    tok.pad_token = tok.eos_token
    return peft_model, tok

# -------------------------------
# Prompt builder for FinanceQA
# -------------------------------
def build_finqa_prompt_from_financeqa(context, query):
    """Build the completion prompt for one FinanceQA example.

    A truthy ``context`` becomes a "Context:" section and a truthy ``query`` a
    "Question:" section, with a blank line between them. The result always ends
    with the fixed "Answer: ... Final Answer:" suffix that the model completes.
    """
    labelled = (("Context:\n", context), ("Question:\n", query))
    sections = [label + str(value).strip() for label, value in labelled if value]
    body = "\n\n".join(sections).strip()
    return (
        f"{body}\n\n"
        "Answer: The final answer is\n\n"
        "Final Answer:"
    )

# Numeric parsing / comparison
MULTIPLIER_MAP = {
    "thousand": 1e3, "k": 1e3, "k.": 1e3,
    "million": 1e6, "m": 1e6, "m.": 1e6,
    "billion": 1e9, "bn": 1e9, "b": 1e9, "b.": 1e9,
}

# First number-like token: an optional sign or "(" (accounting negative), an optional
# currency prefix, the number, an optional magnitude word and an optional "%".
# - The comma-grouped alternative requires at least one ",ddd" group. With `*`
#   it matched "123" out of "1234" and the search stopped there.
# - The magnitude group has no leading \b, because there is no word boundary
#   between a digit and a letter ("5m", "$2bn").
_NUMBER_LIKE_RE = re.compile(
    r"(?P<sign>[-\(])?\s*(?P<prefix>[$€£]|USD|EUR|GBP)?\s*"
    r"(?P<number>-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?)"
    r"\s*(?P<mult>(?:thousand|million|billion|k|m|bn|b)\b)?\s*(?P<percent>%?)",
    flags=re.IGNORECASE,
)


def parse_number_like(s):
    """Extract the first number-like token from the first line of ``s``.

    Returns:
        ``(raw, value, kind)``:
        - ``raw`` is the matched text, or the whole first line when it holds no digits.
        - ``value`` is the float with sign, parentheses and magnitude word applied.
          Percent values are returned unscaled and ignore any magnitude word.
        - ``kind`` is ``'percent'``, ``'currency'`` (currency prefix present) or
          ``'plain'``.
        Both ``value`` and ``kind`` are ``None`` when no number is found, and the
        result is ``(None, None, None)`` when ``s`` is ``None``.
    """
    if s is None:
        return None, None, None
    lines = str(s).strip().splitlines()
    first_line = lines[0].strip() if lines else ""
    m = _NUMBER_LIKE_RE.search(first_line)
    if not m:
        # The number group accepts any digit, so no match means no digits at all.
        return first_line, None, None

    raw = m.group(0).strip()
    # After removing thousands separators the group is always a valid float literal.
    num = float(m.group("number").replace(",", ""))
    if m.group("sign") in ("-", "("):
        # Both a leading minus and an opening parenthesis denote a negative amount.
        num = -abs(num)
    if m.group("percent"):
        return raw, num, "percent"

    kind = "currency" if m.group("prefix") else "plain"
    mult = (m.group("mult") or "").lower()
    if mult:
        num *= MULTIPLIER_MAP.get(mult, 1.0)
    return raw, num, kind

def normalize_answer_for_comparison(ans_str):
    """Parse an answer into ``(value, kind, raw)`` for comparison.

    A list/tuple is reduced to its first element that is neither ``None`` nor blank.
    When no such element exists, the answer is treated as empty. Only the first line
    of the answer is parsed, via ``parse_number_like``.

    Returns:
        - ``(None, None, None)`` for ``None``.
        - ``(None, None, "")`` for an empty or blank answer.
        - Otherwise ``(value, kind, raw)`` from ``parse_number_like``. Note the
          reordering: that function returns ``(raw, value, kind)``.
    """
    if ans_str is None:
        return None, None, None
    if isinstance(ans_str, (list, tuple)):
        chosen = next((el for el in ans_str if el is not None and str(el).strip() != ""), None)
        if chosen is None:
            # Treat an empty / all-blank list as an empty answer. Indexing ans_str[0]
            # would raise IndexError for [] and yield the string "None" for [None].
            return None, None, ""
        ans_str = chosen
    raw = str(ans_str)
    if raw.strip() == "":
        return None, None, ""
    lines = raw.strip().splitlines()
    raw_simple = lines[0].strip() if lines else ""
    raw_str, norm_val, kind = parse_number_like(raw_simple)
    return norm_val, kind, raw_str

def compare_numeric_answers(pred_str, gold_str, rtol=1e-3, atol=1e-2):
    """Decide whether ``pred_str`` matches ``gold_str``.

    When both sides parse to a number, they match if
    ``|pred - gold| <= max(atol, rtol * |gold|)``. The kind (percent / currency /
    plain) is not compared. Otherwise the parsed raw strings are compared
    case-insensitively with runs of whitespace collapsed. Two empty answers never
    match, so an empty generation cannot score against a missing gold answer.

    Returns:
        ``(is_correct, pred_val, gold_val, pred_kind, gold_kind, pred_raw, gold_raw)``.
    """
    pred_val, pred_kind, pred_raw = normalize_answer_for_comparison(pred_str)
    gold_val, gold_kind, gold_raw = normalize_answer_for_comparison(gold_str)
    details = (pred_val, gold_val, pred_kind, gold_kind, pred_raw, gold_raw)
    if pred_val is not None and gold_val is not None:
        tol = max(atol, rtol * abs(gold_val))
        return (abs(pred_val - gold_val) <= tol, *details)
    if pred_raw is not None and gold_raw is not None:
        p = re.sub(r"\s+", " ", pred_raw.strip()).lower()
        g = re.sub(r"\s+", " ", gold_raw.strip()).lower()
        return (bool(p) and p == g, *details)
    return (False, *details)

# Fallback pattern for extract_gold_from_answer_field. The comma-grouped alternative
# needs at least one ",ddd" group; with `*`, "1234" split into "123" and "4", and
# the last match "4" was returned.
_GOLD_NUMBER_RE = re.compile(
    r"(?P<prefix>[$€£]|USD|EUR|GBP)?\s*"
    r"(?P<number>-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?)"
    r"(?P<trail>[$%]?)\s*(?P<mult>(?:thousand|million|billion|k|m|bn|b)\b)?",
    flags=re.IGNORECASE,
)


def extract_gold_from_answer_field(answer_field):
    """Return the gold-answer text to compare predictions against.

    Normally this is the raw number-like token found by
    ``normalize_answer_for_comparison``, or the answer's first line when that line
    holds no number. If that yields nothing, the last number-like token anywhere in
    ``str(answer_field)`` is used, then the whole stripped string. Returns ``""``
    for ``None`` or blank input.
    """
    _, _, raw = normalize_answer_for_comparison(answer_field)
    if raw and raw.strip():
        return raw
    if answer_field is None:
        return ""
    s = str(answer_field).strip()
    if not s:
        return ""
    matches = list(_GOLD_NUMBER_RE.finditer(s))
    if matches:
        return matches[-1].group(0).strip()
    # _GOLD_NUMBER_RE matches any digit, so a separate "N%" search could never succeed here.
    return s

# -------------------------------
# Generation helpers (use device cache)
# -------------------------------
def generate_greedy(model, tok, prompt, max_new_tokens=MAX_NEW_TOKENS):
    """Greedily decode a continuation of ``prompt`` and return only the new text, stripped.

    Inputs come from ``get_tokenized_inputs`` on the device of the model's first
    parameter. Special tokens are dropped when decoding.
    """
    inputs = get_tokenized_inputs(tok, prompt, next(model.parameters()).device)
    prompt_len = inputs["input_ids"].shape[-1]
    decode_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        pad_token_id=tok.eos_token_id,
    )
    with torch.inference_mode():
        sequences = model.generate(**inputs, **decode_kwargs)
    new_tokens = sequences[0][prompt_len:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()

def generate_samples(model, tok, prompt, num_samples=8, batch_size=2, max_new_tokens=1024, temperature=0.7, top_p=0.95):
    """Draw ``num_samples`` sampled continuations of ``prompt`` in batches of ``batch_size``.

    Returns:
        A list of decoded continuations (prompt tokens removed, special tokens
        skipped, whitespace stripped), in generation order.
    """
    device = next(model.parameters()).device
    inputs = get_tokenized_inputs(tok, prompt, device)
    in_len = inputs["input_ids"].shape[-1]

    # A freshly constructed GenerationConfig carries no eos_token_id. Copy the
    # model's own eos id(s) so sampling stops at EOS instead of always running to
    # max_new_tokens; fall back to the tokenizer's eos when the model has none.
    eos_ids = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if eos_ids is None:
        eos_ids = tok.eos_token_id
    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=eos_ids,
    )

    candidates = []
    num_loops = math.ceil(num_samples / batch_size)
    produced = 0
    with torch.inference_mode():
        for _ in range(num_loops):
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                num_return_sequences=this_bs,
                generation_config=gen_conf,
                pad_token_id=tok.eos_token_id,
            )
            candidates.extend(
                tok.decode(seq[in_len:], skip_special_tokens=True).strip() for seq in outs
            )
            produced += this_bs
    return candidates

def generate_answer_self_consistency(model, tok, context, question,
                                     num_samples=16,
                                     batch_size=4,
                                     max_new_tokens=1024,
                                     temperature=0.7,
                                     top_p=0.95,
                                     no_repeat_ngram_size=3,
                                     repetition_penalty=1.1):
    """Answer ``question`` by sampling several continuations and majority-voting them.

    Each sample is cut at a following "Q..." / "OR" / "Stop" line, and only its first
    non-empty line is kept as a candidate. Candidates that parse to a number vote by
    their value rounded to 6 decimals. Only when no candidate is numeric do the
    lower-cased raw strings vote. Ties go to the value seen first.

    Returns:
        ``(chosen, candidates, chosen_norm)``:
        - ``candidates`` lists every kept line.
        - ``chosen`` is the first candidate of the winning group, or ``None`` if
          there are no candidates.
        - ``chosen_norm`` is its numeric value, or ``None`` when the vote was on
          raw text.

    NOTE(review): ``no_repeat_ngram_size=3`` forbids repeating any token trigram,
    which may block answers like "1,000,000" depending on how the tokenizer splits
    digits -- consider passing 0 for numeric QA.
    """
    prompt = build_finqa_prompt_from_financeqa(context, question)
    device = next(model.parameters()).device
    inputs = get_tokenized_inputs(tok, prompt, device)
    input_len = inputs["input_ids"].shape[-1]

    # A freshly constructed GenerationConfig carries no eos_token_id. Copy the
    # model's own eos id(s) so sampling stops at EOS rather than always emitting
    # max_new_tokens.
    eos_ids = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if eos_ids is None:
        eos_ids = tok.eos_token_id
    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        eos_token_id=eos_ids,
    )

    candidates = []
    normalized_values = []
    num_loops = math.ceil(num_samples / batch_size)
    produced = 0
    with torch.inference_mode():
        for _ in range(num_loops):
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                num_return_sequences=this_bs,
                generation_config=gen_conf,
                pad_token_id=tok.eos_token_id,
            )
            for seq in outs:
                cont = tok.decode(seq[input_len:], skip_special_tokens=True).strip()
                # Cut off a follow-up question or an "OR"/"Stop" continuation.
                m = re.search(r"\n\s*(Q[:\d ]|Q\d+:|Q:)", cont)
                if m:
                    cont = cont[:m.start()].strip()
                else:
                    m2 = re.search(r"\n\s*(OR\b|Stop\b)", cont, flags=re.IGNORECASE)
                    if m2:
                        cont = cont[:m2.start()].strip()
                lines = [ln for ln in cont.splitlines() if ln.strip() != ""]
                cont_first = lines[0].strip() if lines else ""
                candidates.append(cont_first)
                normalized_values.append(normalize_answer_for_comparison(cont_first))
            produced += this_bs

    # Vote: numeric buckets first, raw strings only when nothing parsed as a number.
    num_counter = collections.Counter()
    raw_counter = collections.Counter()
    for (nv, _kind, _raw), cand in zip(normalized_values, candidates):
        if nv is not None:
            try:
                num_counter[round(float(nv), 6)] += 1
            except (TypeError, ValueError, OverflowError):
                pass
        else:
            raw_counter[cand.strip().lower()] += 1

    chosen_continuation = None
    chosen_norm = None
    if num_counter:
        chosen_bucket = num_counter.most_common(1)[0][0]
        for (nv, _kind, _raw), cand in zip(normalized_values, candidates):
            if nv is not None and round(float(nv), 6) == chosen_bucket:
                chosen_continuation, chosen_norm = cand, nv
                break
    elif raw_counter:
        chosen_raw = raw_counter.most_common(1)[0][0]
        chosen_continuation = next(
            (cand for cand in candidates if cand.strip().lower() == chosen_raw), None
        )

    return chosen_continuation, candidates, chosen_norm

# -------------------------------
# Orchestration: evaluate 10 models sequentially
# -------------------------------
import gc


def _is_set(value):
    """Return True when ``value`` is non-empty and not a ``"to enter ..."`` placeholder."""
    return bool(value) and not str(value).startswith("to enter")


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` when it has one.

    ``os.path.dirname`` returns ``""`` for a bare file name such as
    ``results.csv``, and ``os.makedirs("")`` raises, so that case is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_log(log_fh, text):
    """Append ``text`` plus a newline to ``log_fh`` (if open), warning instead of crashing on I/O errors."""
    if not log_fh:
        return
    try:
        log_fh.write(text + "\n")
        log_fh.flush()
    except OSError as e:
        print("Could not write to log file:", e)


def _evaluate_model(model_label, model, tok, ds, total, log_fh):
    """Score ``model`` on every example of ``ds`` and return the number answered correctly.

    Uses self-consistency voting when NUM_K_SAMPLING > 1 and greedy decoding
    otherwise. The first PREVIEW examples are echoed to stdout and to ``log_fh``.
    """
    correct = 0
    pbar = tqdm(total=total, desc=f"{model_label}", unit="it")
    try:
        for i, ex in enumerate(ds):
            query = ex.get("QUERY", "") or ex.get("query", "") or ""
            context = ex.get("CONTEXT", "") or ex.get("context", "") or ""
            gold_field = ex.get("ANSWER", "") or ex.get("answer", "") or ""
            gold_for_compare = extract_gold_from_answer_field(gold_field)

            cand_list = None
            if NUM_K_SAMPLING > 1:
                chosen, cand_list, _ = generate_answer_self_consistency(
                    model, tok, context, query,
                    num_samples=NUM_K_SAMPLING,
                    batch_size=BATCH_SIZE,
                    max_new_tokens=MAX_NEW_TOKENS,
                    temperature=0.7,
                    top_p=0.95,
                )
            else:
                prompt = build_finqa_prompt_from_financeqa(context, query)
                chosen = generate_greedy(model, tok, prompt, max_new_tokens=MAX_NEW_TOKENS)

            is_corr, pval, _gval, pk, _gk, pr, _gr = compare_numeric_answers(chosen, gold_for_compare)
            correct += int(is_corr)

            # limited preview logs
            if i < PREVIEW:
                per_sample_lines = [
                    "\n" + "=" * 80,
                    f"[{model_label}] Index {i} | Query:",
                    query,
                    f"Context (first 300 chars): {str(context)[:300]}\n",
                    f"Gold field raw: {gold_field}",
                    f"Gold extracted for compare: {gold_for_compare}\n",
                    f"\n--- Model {model_label} ---",
                ]
                if NUM_K_SAMPLING > 1:
                    per_sample_lines.append("Model output (chosen): " + str(chosen))
                    per_sample_lines.append("Candidates: " + str(cand_list))
                else:
                    per_sample_lines.append("Model output: " + str(chosen))
                per_sample_lines.append("Interpreted Pred Norm: {} | Kind: {} | Raw: {} | Correct? {}".format(pval, pk, pr, is_corr))
                per_sample_lines.append("=" * 80)
                per_sample_text = "\n".join(per_sample_lines)
                print(per_sample_text)
                _write_log(log_fh, per_sample_text)

            pbar.update(1)
    finally:
        pbar.close()
    return correct


def _write_csv_summary(dataset_name, total, headers, results_counts, skipped):
    """Append one summary row (per-model correct counts and accuracies) to CSV_OUT.

    A header row is written first when the file does not exist yet. Skipped models
    get empty cells rather than a misleading 0. Errors are reported, not raised.
    """
    if not _is_set(CSV_OUT):
        print("CSV_OUT placeholder left; skipping CSV write. Set CSV_OUT env var to enable.")
        return
    try:
        file_exists = os.path.exists(CSV_OUT)
        _ensure_parent_dir(CSV_OUT)
        with open(CSV_OUT, "a", newline="", encoding="utf-8") as fh:
            writer = csv.writer(fh)
            if not file_exists:
                header_row = (
                    ["dataset", "N"]
                    + [f"{h}_correct" for h in headers]
                    + [f"{h}_acc" for h in headers]
                    + ["timestamp"]
                )
                writer.writerow(header_row)
            row = [dataset_name, total]
            row += ["" if h in skipped else results_counts.get(h, 0) for h in headers]
            for h in headers:
                if h in skipped:
                    row.append("")
                else:
                    acc = results_counts.get(h, 0) / total if total > 0 else 0.0
                    row.append(f"{acc:.4f}")
            row.append(time.strftime("%Y-%m-%d %H:%M:%S"))
            writer.writerow(row)
        print(f"Wrote CSV summary to: {CSV_OUT}")
    except Exception as e:
        print("Could not write CSV summary:", e)


def _save_accuracy_png(dataset_name, headers, cells):
    """Render a two-row table (model labels / accuracy cells) to a PNG.

    The file goes under LOG_DIR when it is set, otherwise the current directory.
    matplotlib is imported lazily, so a missing install only skips this step.
    """
    try:
        import matplotlib.pyplot as plt

        fig_w = max(8, 0.6 * len(headers))
        fig, ax = plt.subplots(figsize=(fig_w, 1.8))
        ax.axis('off')
        table = ax.table(cellText=[headers, cells], loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.8)
        out_dir = LOG_DIR if _is_set(LOG_DIR) else "."
        png_path = os.path.join(out_dir, f"{dataset_name}_ten_models_accuracy_table.png")
        plt.tight_layout()
        plt.savefig(png_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print("Saved accuracy PNG table to:", png_path)
    except Exception as e:
        print("Could not create/save PNG accuracy table:", e)


def main():
    """Evaluate the pretrained model and each configured LoRA adapter on FinanceQA.

    Models are evaluated one at a time and released before the next one loads.
    Adapters whose tensors directory is unset or missing are reported as skipped.
    Results are printed and, when configured, appended to CSV_OUT and saved as a
    PNG table.
    """
    hf_token = os.environ.get("HF_TOKEN", HF_TOKEN)
    model_name = os.environ.get("MODEL_NAME", MODEL_NAME)

    if _is_set(CSV_OUT):
        _ensure_parent_dir(CSV_OUT)
    log_enabled = _is_set(LOG_DIR)
    if log_enabled:
        os.makedirs(LOG_DIR, exist_ok=True)

    dataset_name_for_log = "FinanceQA"
    log_fh = None
    if log_enabled:
        log_file_path = os.path.join(LOG_DIR, f"{dataset_name_for_log}_detailed_log.txt")
        try:
            log_fh = open(log_file_path, "w", encoding="utf-8")
        except OSError as e:
            print("Could not open log file for writing:", e)

    models_order = [
        "pretrained",
        "full_mlpatt", "full_att", "full_mlp",
        "top1_mlpatt", "top1_att", "top1_mlp",
        "top3_mlpatt", "top3_att", "top3_mlp",
    ]
    tensors_map = {
        "full_mlpatt": TENSORS_DIR_FULL_MLPATT,
        "full_att":    TENSORS_DIR_FULL_ATT,
        "full_mlp":    TENSORS_DIR_FULL_MLP,
        "top1_mlpatt": TENSORS_DIR_TOP1_MLPATT,
        "top1_att":    TENSORS_DIR_TOP1_ATT,
        "top1_mlp":    TENSORS_DIR_TOP1_MLP,
        "top3_mlpatt": TENSORS_DIR_TOP3_MLPATT,
        "top3_att":    TENSORS_DIR_TOP3_ATT,
        "top3_mlp":    TENSORS_DIR_TOP3_MLP,
    }
    results_counts = {m: 0 for m in models_order}
    skipped = set()

    try:
        ds = load_dataset("sweatSmile/FinanceQA", split=f"test[:{N}]")
        total = len(ds)
        print(f"\nLoaded {total} samples from sweatSmile/FinanceQA (test split) - N = {N}\n")

        for model_label in models_order:
            if model_label == "pretrained":
                print("\n--- Evaluating pretrained model ---\n")
                model, tok = load_pretrained(model_name, hf_token)
            else:
                tensors_dir = tensors_map.get(model_label, "")
                print(f"\n--- Loading LoRA adapter '{model_label}' from: {tensors_dir} ---")
                if not _is_set(tensors_dir) or not os.path.exists(tensors_dir):
                    print(f"Skipping {model_label}: tensors dir missing or placeholder ({tensors_dir}).")
                    skipped.add(model_label)
                    continue
                model, tok = load_lora(model_name, hf_token, tensors_dir, r=16, alpha=32)

            try:
                correct = _evaluate_model(model_label, model, tok, ds, total, log_fh)
            finally:
                # Drop every reference first (the token caches hold device tensors),
                # collect, and only then return cached GPU blocks to the driver.
                del model, tok
                _clear_token_caches()
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            results_counts[model_label] = correct
            print(f"\nModel '{model_label}' done: {correct} / {total} = {correct/total if total>0 else 0.0:.3f}")
    finally:
        if log_fh:
            try:
                log_fh.close()
            except OSError as e:
                print("Could not close log file:", e)

    # Final aggregated print (10 columns)
    print("\n\nFINAL AGGREGATED ACCURACY (ALL MODELS) — Dataset: {}  N = {}\n".format(dataset_name_for_log, total))
    cells = []
    for m in models_order:
        if m in skipped:
            print(f"{m}: skipped")
            cells.append("skipped")
            continue
        correct = results_counts[m]
        acc = correct / total if total > 0 else float('nan')
        print(f"{m}: {correct} / {total} = {acc:.3f}")
        cells.append(f"{acc:.4f}")

    _write_csv_summary(dataset_name_for_log, total, models_order, results_counts, skipped)
    _save_accuracy_png(dataset_name_for_log, list(models_order), cells)


if __name__ == "__main__":
    main()